{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bNkqsp45yZBp"
   },
   "source": [
    "<h1>Training and Benchmarking a Model Using Model Zoo</h1>\n",
    "\n",
    "This notebook provides a step-by-step guide on how to use the **STM32AI model zoo** to train, quantize, and benchmark an image classification model. The resulting model can be deployed on STM32 targets, making it ideal for edge computing applications. This notebook can be used in Jupyter Notebook environments, providing flexibility for users who prefer different development environments."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S169jHOt3qLY"
   },
   "source": [
    "## License of the Jupyter Notebook\n",
    "\n",
    "This software component is licensed by ST under BSD-3-Clause license,\n",
    "the \"License\";\n",
    "\n",
    "You may not use this file except in compliance with the\n",
    "License.\n",
    "\n",
    "You may obtain a copy of the License at: https://opensource.org/licenses/BSD-3-Clause\n",
    "\n",
    "Copyright (c) 2023 STMicroelectronics. All rights reserved"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KyAK7KjSix-T"
   },
   "source": [
    "<div style=\"border-bottom: 3px solid #273B5F\">\n",
    "<h2>Table of content</h2>\n",
    "<ul style=\"list-style-type: none\">\n",
    "<li><a href=\"#install\">1. Install necessary packages</a>\n",
    "<li><a href=\"#config\">2. Configure environment variables to access STM32Cube.AI Developer Cloud Services</a></li>\n",
    "<li><a href=\"#upload\">3. Upload the dataset</h2> </a></li>\n",
    "<li><a href=\"#training\">4. Training and Benchmarking the Model</a></li>\n",
    "<li><a href=\"#results\">5. Results</a></li>\n",
    "  </ul>\n",
    "</ul>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NbNXaOtk3qLY"
   },
   "source": [
    "<div id=\"install\">\n",
    "    <h2>1. Install necessary packages</h2>\n",
    "</div>\n",
    "\n",
    "To get started, upload the model zoo package and clone the repository using the following command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NFBa0MB26boA",
    "outputId": "69088c5b-e0be-4a6f-c1e5-6e35b5989d01"
   },
   "outputs": [],
   "source": [
    "!git clone https://github.com/STMicroelectronics/stm32ai-modelzoo_services.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gJUip3GB67e3"
   },
   "source": [
    "Or, you can upload a lighter version of STM32 model zoo by following these steps:\n",
    "- On your local PC clone STM32AI model zoo git using the following command:\n",
    "```\n",
    "git clone https://github.com/STMicroelectronics/stm32ai-modelzoo_services.git\n",
    "```\n",
    "- Delete the .git directory.\n",
    "\n",
    "- For image classification use-case, you can keep only the folders 'image_classification' and 'common', as well as the file 'requirements.txt', then delete the rest.\n",
    "\n",
    "- Zip the repository as stm32ai-modelzoo_services.zip, and upload **stm32-modelzoo_services.zip** in your workspace.\n",
    "\n",
    "- Then uncomment and run the cell below to unzip the folder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tCixcLxV7zr3"
   },
   "outputs": [],
   "source": [
    "# import zipfile\n",
    "# with zipfile.ZipFile('stm32ai-modelzoo_services.zip', 'r') as zip_ref:\n",
    "#     zip_ref.extractall('')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RpPHPEgW-brw"
   },
   "source": [
    "Next, run the following command to install the required packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "g9DQwE8G-hYv",
    "outputId": "bdd17977-dd65-4fcf-ef67-3f5274979b4f"
   },
   "outputs": [],
   "source": [
    "!pip install -r stm32ai-modelzoo_services/requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PFXdwWJHBin-"
   },
   "source": [
    "<div id=\"config\">\n",
    "    <h2>2. Configure environment variables to access STM32Cube.AI Developer Cloud Services</h2>\n",
    "</div>\n",
    "Set environment variables with your credentials to acces STM32Cube.AI Developer Cloud Services.\n",
    "\n",
    "If you don't have an account yet go to: https://stedgeai-dc.st.com/home and click on sign in to create an account.\n",
    "\n",
    "Then set the environment variables below with your credentials.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "u-TXBTDpB6WT"
   },
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "os.environ['stmai_username'] = 'xxx.yyy@st.com'\n",
    "print('Enter you password')\n",
    "password = getpass.getpass()\n",
    "os.environ['stmai_password'] = password"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "c8eeXs6h9S6s"
   },
   "source": [
    "<div id=\"upload\">\n",
    "    <h2>3. Upload the dataset</h2>\n",
    "</div>\n",
    "The dataset can be uploaded as a zip archive named **dataset.zip** under the directory 'workspace/stm32ai-modelzoo_services/image_classification/datasets'.\n",
    "\n",
    "The zip file shall contain a directory named \"dataset\" with one sub-directory per category, with images inside as below:\n",
    "\n",
    "```bash\n",
    "dataset_root_directory/\n",
    "   class_a/\n",
    "      a_image_1.jpg\n",
    "      a_image_2.jpg\n",
    "   class_b/\n",
    "      b_image_1.jpg\n",
    "      b_image_2.jpg\n",
    "```\n",
    "Other dataset formats are not supported. The only exceptions are the Cifar10/Cifar100 datasets. For these datasets, the official format in batches is supported.\n",
    "\n",
    "The split between training and validation sets is done automatically by the scripts. However, it is also possible to upload specific training, validation, and test sets by defining specific paths in the user_config.yaml file.\n",
    "\n",
    "In this tutorial we are going to use the flower dataset that can be downloaded directly from tensorflow repository: https://www.tensorflow.org/datasets/catalog/tf_flowers (Creative Commons By-Attribution License 2.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ip04ofreCrYC",
    "outputId": "1655cf38-0178-4d00-ca0f-756a040b5fb9"
   },
   "outputs": [],
   "source": [
    "dataset_name = 'tf_flowers' #@param [\"custom\", \"tf_flowers\"]\n",
    "%cd stm32ai-modelzoo_services/image_classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 387
    },
    "id": "O9Zv5TFhBz6N",
    "outputId": "f256deab-4d1f-4448-a9e2-1ab5e4cc9a89"
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as mpimg\n",
    "import numpy as np\n",
    "import os\n",
    "import zipfile\n",
    "\n",
    "# Define the path to the dataset directory\n",
    "if dataset_name == 'tf_flowers':\n",
    "    path = 'datasets/flower_photos'\n",
    "    !wget http://download.tensorflow.org/example_images/flower_photos.tgz -P datasets\n",
    "    !tar -xf datasets/flower_photos.tgz -C datasets\n",
    "else:\n",
    "    path = 'datasets/dataset'\n",
    "    with zipfile.ZipFile('datasets/dataset.zip', 'r') as zip_ref:\n",
    "        zip_ref.extractall('datasets')\n",
    "\n",
    "# Get the list of class names\n",
    "class_names = sorted([name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))])\n",
    "num_classes = len(class_names)\n",
    "print(f\"Classes: {class_names}\")\n",
    "print(f\"Introducing samples from each class...\")\n",
    "\n",
    "# Print a photo from each class\n",
    "fig, axs = plt.subplots(1, num_classes, figsize=(4*num_classes, 4))\n",
    "for i, class_name in enumerate(class_names):\n",
    "    class_path = os.path.join(path, class_name)\n",
    "    image_files = [f for f in os.listdir(class_path) if os.path.isfile(os.path.join(class_path, f)) and f.endswith('.jpg')]\n",
    "    if len(image_files) == 0:\n",
    "        img = np.zeros((224, 224, 3))\n",
    "    else:\n",
    "        img_path = os.path.join(class_path, random.choice(image_files))\n",
    "        img = mpimg.imread(img_path)\n",
    "    axs[i].imshow(img)\n",
    "    axs[i].set_title(class_name)\n",
    "    axs[i].axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wNCJANEqix-W"
   },
   "source": [
    "<div id=\"training\">\n",
    "    <h2>4. Training and Benchmarking the Model</h2>\n",
    "</div>\n",
    "\n",
    "The STM32 model zoo is an invaluable resource that provides a wide range of use cases, including image classification, object detection, audio event detection, hand posture, and human activity recognition. The model zoo offers various services, including training, evaluation, prediction, deployment, quantization, benchmarking, and chained services. These services, such as chain_tbqeb, chain_tqe, chain_eqe, chain_qb, chain_eqeb, and chain_qd, are thoroughly explained in their respective readmes.\n",
    "\n",
    "In this section, we will demonstrate how to train, quantize, evaluate, and benchmark a classification model using the chain_tbqeb service. We will use the MobileNet v2 0.35 model from the model zoo as an example, but you can also use your own custom model. To accomplish this, we will use the `user_config.yaml` file as a configuration file to specify the service and the set of configuration parameters, such as the model, dataset, number of epochs, and preprocessing parameters. Please feel free to review and adjust the training parameters as needed.\n",
    "\n",
    "For a custom dataset, in the dataset section, modify:\n",
    "*   the name and class_names accordingly.\n",
    "*   training path: `training_path: ../datasets/dataset`\n",
    "\n",
    "Then, you can tune the other parameters and save the file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 636
    },
    "id": "yodEvVHgEG2k",
    "outputId": "563a52d8-09e0-47ad-de0d-55702c749f87"
   },
   "outputs": [],
   "source": [
    "%cd src\n",
    "%run stm32ai_main.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eYcsLaTSGi7J"
   },
   "source": [
    "<div id=\"results\">\n",
    "    <h2>5. Results</h2>\n",
    "</div>\n",
    "The trained and quantized models, along with any artifacts, plots, and figures related to the experiments, can be found in the 'workspace/stm32ai-modelzoo_services/image_classification/src/experiments_outputs' directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "id": "P_P-e9nc6jdl",
    "outputId": "404e11d7-965a-4edb-a52e-ea9df5cb7dff"
   },
   "outputs": [],
   "source": [
    "import shutil\n",
    "shutil.make_archive('experiments_outputs', 'zip', 'experiments_outputs')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "3ca9c95fb3295dba58147778a3f6149a36aba268806f86b68ae4a365fcdcc5ff"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
